# model_age_both

#### Set-Up ####

# Start from an empty global environment so no stale objects leak into the run.
rm(list = ls())
Sys.info()[7]  # log the user running the script

# Set file paths
# All paths are relative: the script must be launched with the code folder as
# the working directory (the former `setwd("./")` was a no-op and is dropped).
DATA <- "../external_links/real"
TEMP <- "../temp"
RESULTS <- "../output"

# From here on, printed output goes to this log file until the sink() call at
# the end of the script.
sink("estimate_xs.txt")

# Make sure both output folders exist before anything is written to them.
# recursive = TRUE also creates a missing parent (TEMP / RESULTS); without it
# dir.create() only warns and the later write.csv()/write.table() calls fail.
for (folder in paste0(c(TEMP, RESULTS), "/estimate_w_xs")) {
  if (dir.exists(folder)) {
    cat("The folder already exists")
  } else {
    dir.create(folder, recursive = TRUE)
  }
}

Sys.info()[7]
Sys.time()

# Packages
library(MASS)
library(tmvtnorm)
library(haven)
library(dplyr)
library(tidyr)
library(nloptr)
library(corpcor)
library(stargazer)
library(pracma)
library(foreign)
library(writexl)
library(broom)
library(png)
library(tidyverse)
# new packages:
library(compositions) # for rlnorm.rplus

#### Load Data + Assumptions ####
Sys.info()[7]

# Alternative input file, kept for reference:
# analysis <- read_dta(paste0(TEMP, "/estimate_w_xs/model_sample_recon.dta")) %>%
analysis <- read_dta(paste0(DATA, "/model_sample.dta"))
# Missing oop_i values are recoded to Inf.
analysis <- mutate(analysis, oop_i = ifelse(is.na(oop_i), Inf, oop_i))

print(DATA)

# Scalar assumptions. They are not referenced again in this script;
# presumably the sourced model functions read them as globals - TODO confirm.
pa <- 0.005
pfp <- 0.01
pfn <- 0.01
J <- 1

#file.remove(paste0(TEMP, "/Prenatal/estimate_w_xs/estimates.csv"))
#file.remove(paste0(RESULTS, "/Prenatal/estimate_w_xs/estimates.csv"))

#### Load Necessary Model Functions ####
source("r_functions/production_model_xs_new.R")

#### ESTIMATE ####
print("Estimate on actual data")

# Moments of the observed data and their weights: moments_wgts() returns a
# list whose first element is used as the moments and second as the weights.
moments_out <- moments_wgts(data = analysis)
data_moments <- moments_out[[1]]
weights <- moments_out[[2]]
rm(moments_out)
print("calculated actual moments")

# Save moments and weights side by side for inspection.
data_moments_table <- data.frame(data_moments, weights)
write.csv(
  data_moments_table,
  file = paste0(RESULTS, "/estimate_w_xs/data_moments_table.csv"),
  na = ".",
  row.names = TRUE
)

# print("table(analysis$did_nipt, analysis$wave)")
# table(analysis$did_nipt, analysis$wave)

# Objective functions handed to the optimizer.
# `analysis`, `f_a`, `f_c` and `x_vars` are looked up when target() is called,
# not when it is defined, so the versions set further down the script (the
# age_norm column and the formulas) are the ones used during estimation.
target <- function(x) {
  obj(
    x, analysis, data_moments, weights,
    f_a = f_a, f_c = f_c, x_vars = x_vars,
    unique_flag = 1, info = FALSE
  )
}

# Gradient of `target` computed numerically by nl.grad().
g_target <- function(x) {
  nl.grad(x, target)
}

# To use parameters inclusive of mean shifting:
# Using f_a = ~ a_i + age
# and   f_c = ~ c_i + age
# Means TWO new parameters need to be added to the end of this 
# vector below. The 2.0 will be the coefficient on age in f_a,
# and the -3.0 will be the coefficient on age in f_c.

# Bounds when using the mean shifter! 
# when we add a new parameter, add in upper and lower bounds for the parameter
# for categorical variables, verify the number of distinct values, and add that number of elements to the bounds
# (mu_a, mu_c, se_a, se_c, rho, psi, age_a, age_c)
# lb = c(-300000, -300000, 1, 1, -0.9, 0.5, -999999, -999999)
# ub = c(0, 0, 100000, 100000, 0.9, 1, 999999, 999999)
# par0 <- c(-126772.04, -201776.45, 1, 61968.99,  0.8795, 0.9569, 0, 0)

# lb = c(0, 0, 0.1, 0.1, -0.9, 0, -15, -15)
# ub = c(15, 15, 10, 10, 0.9, 1, 15, 15)
# par0 <- c(4, 4, 4, 4, 0.5, 0.5, 0, 0)

# Candidate specification: box constraints and starting values, one entry per
# parameter, following the ordering comment above.
#          mu_a  mu_c  se_a  se_c   rho   psi  age_a  age_c
lb   <- c(    0,    0,  0.1,  0.1, -0.9,    0,    -2,    -2)
ub   <- c(   15,   15,   10,   10,  0.9,    1,     2,     2)
par0 <- c(    4,    4,    4,    4,  0.5,  0.5,     0,     0)

# # testing age shifters
# # 11.78955 12.20313 0.4371086 0.5564466 0.8151445 0.9203451 0 0
# lb = c(11.78955, 12.20313, 0.4371086, 0.5564466, 0.8151445,0.9203451, -1, -1)
# ub = c(11.78955, 12.20313, 0.4371086, 0.5564466, 0.9203451, 0.9203451, 1, 1)
# par0 <- c(11.78955, 12.20313, 0.4371086, 0.5564466, 0.8151445, 0.9203451, 0, 0)


# 1 x 1 character matrix; written as a single separator row into the
# estimates log by doublesolve().
divider <- t("####################")

# Wrapper function to use a global method first, followed by a local method
# started from the global solution.
# alg1 and alg2 may be any algorithm string supported by NLOPTR.
#
# Arguments:
#   x0         numeric start vector for the global stage
#   eval_f     objective function of the parameter vector
#   eval_grad  gradient of eval_f. It is forwarded to nloptr only for a stage
#              whose algorithm is derivative-based (NLopt "_LD_" / "_GD_"
#              names). The default algorithms are derivative-free, so with the
#              defaults it is not used.
#   lb, ub     lower / upper bounds for the parameters
#   alg1       algorithm for the global stage
#   alg2       algorithm for the local stage
#
# Returns list(global_result, local_result), the two nloptr results in order.
#
# Side effects: calls set.seed(101), prints both results, and appends the
# global `divider` row to estimates.csv under TEMP and RESULTS before each
# stage (presumably so the rows logged during each stage can be told apart -
# confirm against obj()). Relies on the globals `divider`, `TEMP`, `RESULTS`.
#
# normally, o1 maxtime == 2419200
doublesolve <- function(
  x0,
  eval_f,
  eval_grad=NA,
  lb,
  ub,
  alg1="NLOPT_GN_DIRECT_NOSCAL",
  alg2="NLOPT_LN_SBPLX"
) {
  # Gradient to pass for `alg`: the supplied function when the algorithm is
  # derivative-based, otherwise NULL (nloptr's own default for eval_grad_f).
  grad_for <- function(alg) {
    if (is.function(eval_grad) && grepl("_[GL]D_", alg)) eval_grad else NULL
  }

  # Append the divider row to both estimate logs (TEMP first, then RESULTS).
  log_files <- paste0(c(TEMP, RESULTS), "/estimate_w_xs/estimates.csv")
  write_divider <- function() {
    for (log_file in log_files) {
      write.table(divider, file = log_file, sep = ",", na = ".",
                  append = TRUE, col.names = FALSE, row.names = FALSE)
    }
  }

  # Global stage: xtol_rel = 0, so maxeval is the only stopping rule set here.
  # NOTE(review): nloptr documents the option name `population`; confirm that
  # `pop.size` is actually honoured rather than silently ignored.
  o1 <- list("maxeval"=10000, # used to be maxtime 64800 maxeval 10000000
             "xtol_rel"=0,
             "algorithm"=alg1,
             "pop.size" = (length(x0) + 1) * 100, # If convergence is too slow, can
             # lower the 100 to ~10ish
             "print_level"=3,
             "ranseed" = 4792
  )

  # Local stage: refine until the relative change in x falls below 1e-8.
  o2 <- list("maxeval"=100000, # originally 100000
             "xtol_rel"=1.0e-8,
             "algorithm"=alg2,
             "print_level"=3,
             "ranseed" = 2837
  )

  # Fix R's RNG so any R-side randomness inside eval_f is reproducible.
  set.seed(101)

  write_divider()
  s1 <- nloptr(x0 = x0, eval_f = eval_f, eval_grad_f = grad_for(alg1),
               lb = lb, ub = ub, opts = o1)
  print("Results of Global optimization")
  print(s1)

  write_divider()
  s2 <- nloptr(x0 = s1$solution, eval_f = eval_f, eval_grad_f = grad_for(alg2),
               lb = lb, ub = ub, opts = o2)
  print("Results of Local optimization")
  print(s2)

  list(s1, s2)
}

####### Normalize X's variables #############
# Standardize age: subtract the sample mean and divide by the sample sd.
analysis <- mutate(analysis, age_norm = (age - mean(age)) / sd(age))

# Log mean and sd of the raw and the standardized column, in that order.
print("age stats:")
invisible(lapply(analysis[c("age", "age_norm")], function(v) {
  print(mean(v))
  print(sd(v))
}))

# Shifter formulas (no intercept) and the covariate list passed to the model.
f_a <- ~ 0 + age_norm
f_c <- ~ 0 + age_norm
x_vars <- c("age_norm")

###### Optimize ####
# Run the global-then-local solve and report how long it took.
system.time(
  estimate <- doublesolve(
    x0 = par0, eval_f = target, eval_grad = g_target, lb = lb, ub = ub
  )
)

# doublesolve() returns list(global_result, local_result).
global <- estimate[1]
local <- estimate[2]

# Global-stage solution, written without header or row names.
global_list <- estimate[[1]]
global_est <- global_list$solution
global_est <- as.data.frame(global_est)
write.table(
  global_est,
  file = paste0(RESULTS, "/estimate_w_xs/estimates_global_no_xs.csv"),
  sep = ",",
  col.names = FALSE,
  row.names = FALSE
)

# Extract results
# Local-stage solution followed by the objective value and iteration count,
# stacked in a single column.
est <- estimate[[2]]
opt_est <- est$solution
results_1 <- c(est$solution, est$objective, est$iterations)
results_1 <- as.data.frame(results_1)
write_csv(results_1, file = paste0(RESULTS, "/estimate_w_xs/results_no_xs.csv"))
rm(estimate)

# Re-evaluate the objective at the local optimum with info = TRUE to get the
# per-moment diagnostics.
# NOTE(review): unique_flag is not passed here although `target` passes
# unique_flag = 1 - confirm that obj()'s default matches.
set.seed(4793285)
ans_local <- obj(
  est$solution, analysis, data_moments, weights,
  f_a = f_a, f_c = f_c, x_vars = x_vars,
  info = TRUE
)

#####################################
# Construct local opt moment comparison table #
#####################################
# One row per moment: name, model value, data value and weight.
moments_local <- data.frame(
  name = ans_local$moment_names,
  model_moments = ans_local$model,
  data_moments = ans_local$data,
  weight = weights
)

# Write the moment comparison table to disk
write.csv(
  moments_local,
  file = paste0(RESULTS, "/estimate_w_xs/moments_local_opt_global2.csv"),
  na = ".",
  row.names = TRUE
)

### TEMPORARILY CHANGE EST_PAR DEFINITION
# The parameter vector used for all simulations below is the local optimum.
par_est <- opt_est <- est$solution

# Same file written to the TEMP folder first, then to the RESULTS folder.
invisible(lapply(c(TEMP, RESULTS), function(root) {
  write.csv(
    par_est,
    file = paste0(root, "/estimate_w_xs/nipt_estimates_rec.csv"),
    na = ".",
    row.names = TRUE
  )
}))

set.seed(101)

# Columns handed to the model: the fixed set below plus the covariates in x_vars.
actual <- dplyr::select(
  analysis,
  pregnancy, bin_number, p_i, fetus_risk, wave, did_nipt, did_invasive,
  oop_i, policy_id, policy_regime, all_of(x_vars)
)

# One simulated decision set per row of `actual`, at the estimated parameters.
simulated <- as.data.frame(
  model_decisions(
    par = par_est, data = actual,
    f_a = f_a, f_c = f_c, x_vars = x_vars, unique_flag = 1
  )
)
simulated <- dplyr::select(simulated, a_i, c_i, pred_nipt, pred_invasive)

# altogether
all <- cbind(actual, simulated)

temp <- actual

### simulate 100 draws per pregnancy
K <- 100

# Stack K copies of `actual` (copy 1, copy 2, ..., copy K) in a single call.
# This replaces growing the data frame with rbind() inside a loop, which
# re-copied the accumulated rows on every iteration, and it no longer needs
# the loop index / helper copy that rm() warned about when K <= 1.
if (K > 1) {
  actual <- dplyr::bind_rows(rep(list(actual), K))
}

set.seed(101)
simulated_100 <- model_decisions(par=par_est, data=actual, f_a = f_a, f_c = f_c, x_vars = x_vars, unique_flag = 1) %>%
  as.data.frame() %>%
  dplyr::select(a_i, c_i, pred_nipt, pred_invasive)

# Replicated inputs next to their simulated outcomes.
all_100 <- cbind(actual, simulated_100)

# create X*betas
# Elements 7 and 8 of the parameter vector are taken as the covariate
# coefficients (age_a, age_c in the parameter-order comment near the bounds).
betas <- par_est[7:8]
print(betas)
a_beta <- betas[1]
print(a_beta)
c_beta <- betas[2]
print(c_beta)

# Covariate column(s) only.
xs_only <- dplyr::select(all, all_of(x_vars))

# Covariate times its coefficient. The output column names are hard-coded
# for a single covariate (age_norm).
x_beta_a <- setNames(data.frame(mapply("*", xs_only, a_beta)), "age_norm_a")
x_beta_c <- setNames(data.frame(mapply("*", xs_only, c_beta)), "age_norm_c")
x_beta <- cbind(x_beta_a, x_beta_c)

# With one covariate the summed X*beta equals the single product column.
risk_plot_output <- mutate(
  cbind(all, x_beta),
  sum_xbeta_a = age_norm_a,
  sum_xbeta_c = age_norm_c
)

rm(list = c("actual", "simulated"))

# Row-level predictions: the same file goes to TEMP first, then to RESULTS.
invisible(lapply(c(TEMP, RESULTS), function(root) {
  write.csv(
    all,
    file = paste0(root, "/estimate_w_xs/model_predict_disagg_no_xs.csv"),
    na = ".",
    row.names = TRUE
  )
}))

# Inputs for the risk plots: the single-draw table with X*beta columns, then
# the replicated-draw table.
write.csv(
  risk_plot_output,
  file = paste0(RESULTS, "/estimate_w_xs/risk_plot_no_xs.csv"),
  na = ".",
  row.names = TRUE
)
write.csv(
  all_100,
  file = paste0(RESULTS, "/estimate_w_xs/risk_plot_100_no_xs.csv"),
  na = ".",
  row.names = TRUE
)


#### EXIT ####
print("finished estimate_xs!")
# Stop diverting output to the log file opened at the top of the script.
sink()
#rm(list = ls())
